Demote(B)Float16 pass: only keep enabled for PPC. #55486
Conversation
So the backends are now doing the right thing by default? Except PPC?
@gbaraldi yes, see #55479 (comment)
Force-pushed from 230e4a8 to 3c4cc92.
@giordano Can you check that this now does the right thing on Grace?
On this PR I get

julia> code_llvm(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
; Function Signature: var"#7"(Core.BFloat16, Core.BFloat16)
define bfloat @"julia_#7_3558"(bfloat %"a::BFloat16", bfloat %"b::BFloat16") #0 {
top:
%bitcast_coercion = bitcast bfloat %"a::BFloat16" to i16
%0 = zext i16 %bitcast_coercion to i32
%1 = shl nuw i32 %0, 16
%bitcast_coercion2 = bitcast i32 %1 to float
%2 = fmul float %bitcast_coercion2, %bitcast_coercion2
%3 = fcmp ord float %2, 0.000000e+00
br i1 %3, label %L13, label %L32
L13: ; preds = %top
%bitcast_coercion104 = bitcast float %2 to i32
%4 = lshr i32 %bitcast_coercion104, 16
%5 = and i32 %4, 1
%narrow = add i32 %bitcast_coercion104, 32767
%6 = add i32 %narrow, %5
%7 = and i32 %6, -65536
%8 = bitcast i32 %7 to float
br label %L32
L32: ; preds = %L13, %top
%bitcast_coercion28 = phi float [ %8, %L13 ], [ 0x7FF8000000000000, %top ]
%bitcast_coercion11 = bitcast bfloat %"b::BFloat16" to i16
%9 = zext i16 %bitcast_coercion11 to i32
%10 = shl nuw i32 %9, 16
%bitcast_coercion16 = bitcast i32 %10 to float
%11 = fmul float %bitcast_coercion16, %bitcast_coercion16
%12 = fcmp ord float %11, 0.000000e+00
br i1 %12, label %L44, label %L63
L44: ; preds = %L32
%bitcast_coercion84 = bitcast float %11 to i32
%13 = lshr i32 %bitcast_coercion84, 16
%14 = and i32 %13, 1
%narrow126 = add i32 %bitcast_coercion84, 32767
%15 = add i32 %narrow126, %14
%16 = and i32 %15, -65536
%17 = bitcast i32 %16 to float
br label %L63
L63: ; preds = %L44, %L32
%bitcast_coercion35 = phi float [ %17, %L44 ], [ 0x7FF8000000000000, %L32 ]
%18 = fadd float %bitcast_coercion28, %bitcast_coercion35
%19 = fcmp ord float %18, 0.000000e+00
br i1 %19, label %L94, label %L102
L94: ; preds = %L63
%bitcast_coercion64 = bitcast float %18 to i32
%20 = lshr i32 %bitcast_coercion64, 16
%21 = and i32 %20, 1
%narrow127 = add i32 %bitcast_coercion64, 32767
%22 = add i32 %narrow127, %21
%23 = and i32 %22, -65536
%24 = bitcast i32 %23 to float
%25 = fcmp uge float %24, 0.000000e+00
br i1 %25, label %L102, label %L100
L100: ; preds = %L94
call void @j_throw_complex_domainerror_3570(ptr nonnull @"jl_sym#sqrt#3571.jit", float %24) #10
unreachable
L102: ; preds = %L94, %L63
%bitcast_coercion44130 = phi float [ %24, %L94 ], [ 0x7FF8000000000000, %L63 ]
%26 = call float @llvm.sqrt.f32(float %bitcast_coercion44130)
%27 = fcmp ord float %26, 0.000000e+00
br i1 %27, label %L107, label %L126
L107: ; preds = %L102
%bitcast_coercion53 = bitcast float %26 to i32
%28 = lshr i32 %bitcast_coercion53, 16
%29 = and i32 %28, 1
%narrow128 = add nuw nsw i32 %29, 32767
%30 = zext nneg i32 %narrow128 to i64
%31 = zext i32 %bitcast_coercion53 to i64
%32 = add nuw nsw i64 %30, %31
%33 = lshr i64 %32, 16
%34 = trunc i64 %33 to i16
%bitcast_coercion62 = bitcast i16 %34 to bfloat
br label %L126
L126: ; preds = %L107, %L102
%value_phi51 = phi bfloat [ %bitcast_coercion62, %L107 ], [ 0xR7FC0, %L102 ]
ret bfloat %value_phi51
}

This is exactly the same IR I get on nightly (modulo the gensymed function names). It's not doing native bfloat operations here.
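For reference, the lshr/and/add 32767 pattern in each of those blocks is the usual software round-to-nearest-even truncation from Float32 to BFloat16. An illustrative Julia version of that bit trick (a sketch for clarity, not the code Julia emits, and without the NaN special case handled by the fcmp ord branches above):

function trunc_to_bfloat_bits(x::Float32)
    u   = reinterpret(UInt32, x)
    lsb = (u >> 16) & 0x0000_0001      # low bit of the would-be result, used to break ties
    u  += 0x0000_7fff + lsb            # round to nearest, ties to even
    return UInt16(u >> 16)             # keep the high 16 bits: the BFloat16 bit pattern
end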
Ah, that's because BFloat16s needs to be updated: https://github.com/JuliaMath/BFloat16s.jl/blob/2266cc578d973bbd27fde7fc25a3d6dea0160f80/src/bfloat16.jl#L20-L49

FWIW, we need to be careful about extending that check, because e.g. AArch64 doesn't even support bf16 arithmetic on LLVM 18, only on 19: https://godbolt.org/z/vPhbYWrT4
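To make that concrete, here is a minimal sketch of the kind of target/LLVM-version gate being discussed; it is illustrative only, not the actual BFloat16s.jl logic (which lives at the linked bfloat16.jl lines), and the policy for targets other than AArch64 is a placeholder assumption:

# Sys.ARCH and Base.libllvm_version are real Julia constants; the policy is a sketch.
native_bfloat_ok() =
    if Sys.ARCH === :aarch64
        # bf16 arithmetic only lowers correctly on LLVM 19+, per the godbolt link above
        Base.libllvm_version >= v"19"
    elseif Sys.ARCH in (:powerpc64le, :ppc64le)
        # BFloat16 isn't supported on PPC, so the Demote(B)Float16 pass stays enabled there
        false
    else
        # placeholder: other targets need their own, separately verified condition
        false
    end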
With a Julia build against LLVM 19 + JuliaMath/BFloat16s.jl#77 on macos-aarch64 (M3):

julia> code_llvm(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
define bfloat @"julia_#1_2245"(bfloat %"a::BFloat16", bfloat %"b::BFloat16") #0 {
L9:
%0 = fmul bfloat %"a::BFloat16", %"a::BFloat16"
%1 = fmul bfloat %"b::BFloat16", %"b::BFloat16"
%2 = fadd bfloat %0, %1
%3 = fpext bfloat %2 to float
%4 = call float @llvm.sqrt.f32(float %3)
%5 = fptrunc float %4 to bfloat
ret bfloat %5
}

The generated code of course still contains conversions since, as you mentioned, there are no scalar bf16 instructions yet:

julia> code_native(NTuple{2,BFloat16}; debuginfo=:none) do a, b
           sqrt(a * a + b * b)
       end
fmov w8, s0
lsl w8, w8, #16
fmov s0, w8
fmul s0, s0, s0
bfcvt h0, s0
fmov w8, s1
lsl w8, w8, #16
fmov s1, w8
fmul s1, s1, s1
bfcvt h1, s1
fmov w8, s1
lsl w8, w8, #16
fmov s1, w8
fmov w8, s0
lsl w8, w8, #16
fmov s0, w8
fadd s0, s0, s1
bfcvt h0, s0
fmov w8, s0
lsl w8, w8, #16
fmov s0, w8
fsqrt s0, s0
bfcvt h0, s0
ret
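As a sanity check on what those lsl w8, w8, #16 widenings do: a BFloat16 is just the high 16 bits of the corresponding Float32, so the shift-based widening is exact. A small illustrative snippet, assuming BFloat16s.jl is loaded:

using BFloat16s

x = BFloat16(1.5)
# widening BFloat16 -> Float32 places the 16 bits in the high half of the Float32
widened = reinterpret(Float32, UInt32(reinterpret(UInt16, x)) << 16)
@assert widened == Float32(x)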
LLVM should handle this properly now for everything but PPC (where BFloat16 isn't supported anyway).
I considered stripping the bf16 bits from the pass, but went for the more conservative change for now in case we discover issues lurking in targets that aren't covered by CI.
Fixes #55479
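One illustrative way to spot-check the effect on a non-PPC machine is to look at the IR for plain Float16 arithmetic; with the demotion pass disabled, the multiplication should show up as a native half operation (e.g. fmul half) rather than being widened to float. This is a hedged expectation, not verified output:

julia> code_llvm(NTuple{2,Float16}; debuginfo=:none) do a, b
           a * b
       end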